#!/usr/bin/python
"""A script to plot the output of rsvd_train.

Plots the training and probe error (RMSE per epoch) using matplotlib.
Reads the output of rsvd_train from stdin and echoes each line to stdout.

Usage: 
$ rsvd_train --probe probe.arr train.arr x x x | rsvd_lc
"""

import sys
import re
import pylab as pl
import time

# Matches one progress line made of four whitespace-separated fields: an
# unsigned integer followed by three unsigned decimals (digits on both sides
# of the point; no sign, no exponent). main() reads the first three groups as
# epoch, training error and probe error; the fourth is not plotted
# (presumably a timing value -- TODO confirm against rsvd_train's output).
rec = re.compile(r"^(\d+)\s+(\d+\.\d+)\s+(\d+\.\d+)\s+(\d+\.\d+)$")

class Curve(object):
    """A line that grows one sample at a time on a matplotlib-style axes.

    The accumulated samples live in ``xdata`` / ``ydata``; ``curve`` is the
    animated line artist that mirrors them.
    """

    def __init__(self, ax, lw=1):
        """Create an empty animated line on ``ax`` drawn with width ``lw``."""
        self.xdata = []
        self.ydata = []
        # ax.plot returns a sequence of artists; exactly one is expected.
        (self.curve,) = ax.plot([], [], animated=True, lw=lw)

    def update(self, x, y):
        """Append the sample ``(x, y)`` and hand the full series to the artist."""
        xs, ys = self.xdata, self.ydata
        xs.append(x)
        ys.append(y)
        self.curve.set_data(xs, ys)

def main():
    """Plot the train/probe learning curves read live from stdin.

    Builds a figure with an RMSE-vs-epochs axes and two ``Curve`` objects,
    schedules the nested ``run`` callback on the figure window, and then
    enters ``pl.show()``. ``run`` echoes every stdin line to stdout and, for
    each line matching ``rec``, appends a point to both curves and blits
    them onto the canvas. It stops reading at end of input.
    """
    fig = pl.figure()
    ax = fig.add_subplot(111)
    ax.set_ylim(0.8, 1.2)
    ax.set_xlim(0, 10)

    traincurve = Curve(ax)
    probecurve = Curve(ax)

    pl.ylabel("RMSE")
    pl.xlabel("epochs")
    pl.legend(["train error", "probe error"])

    def run(*args):
        """Consume stdin until EOF, updating the plot for each data line."""
        # write() rather than the print statement so the script behaves the
        # same on Python 2 and Python 3.
        sys.stdout.write("rsvd learn curve reading from stdin...\n")

        # Render the static parts once and keep a copy of the axes region:
        # each blit below starts from this clean background instead of
        # painting on top of the previous frame.
        fig.canvas.draw()
        background = fig.canvas.copy_from_bbox(ax.bbox)

        # readline() returns "" (never None) at EOF. iter() with that
        # sentinel ends the loop there while still fetching one line at a
        # time, so each line is handled as soon as it arrives.
        for line in iter(sys.stdin.readline, ""):
            sys.stdout.write(line)
            sys.stdout.flush()

            m = rec.match(line)
            if m is None:
                # Not a progress line; it has been echoed, nothing to plot.
                continue

            epoch = int(m.group(1))
            trainerr = float(m.group(2))
            probeerr = float(m.group(3))
            # The fourth field matched by ``rec`` is not plotted.

            xmin, xmax = ax.get_xlim()
            if epoch >= xmax:
                # Ran off the right edge: double the x range, redraw the
                # static parts and refresh the saved background.
                ax.set_xlim(xmin, 2 * xmax)
                fig.canvas.draw()
                background = fig.canvas.copy_from_bbox(ax.bbox)

            traincurve.update(epoch, trainerr)
            probecurve.update(epoch, probeerr)

            # Start from the clean background, draw only the animated
            # artists, then push just the axes rectangle to the screen.
            fig.canvas.restore_region(background)
            ax.draw_artist(traincurve.curve)
            ax.draw_artist(probecurve.curve)
            fig.canvas.blit(ax.bbox)

    manager = pl.get_current_fig_manager()
    # NOTE(review): ``window.after`` looks like the Tk scheduling call, so
    # this presumably requires a Tk-based matplotlib backend -- confirm.
    manager.window.after(100, run)
    pl.show()


if __name__ == "__main__":
    main()
